{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "In the real world, many graphs contain multiple types of nodes and edges; these are called heterogeneous graphs. Because of this variety, heterogeneous graphs are more complex to model than homogeneous graphs.\n",
    "\n",
    "To handle such graphs, PGL provides a graph framework that supports graph neural network computations and meta-path-based sampling on heterogeneous graphs.\n",
    "\n",
    "The goals of this tutorial are to:\n",
    "* see an example of heterogeneous graph data;\n",
    "* understand how PGL supports computations on heterogeneous graphs;\n",
    "* use PGL to implement a simple heterogeneous graph neural network that classifies a particular type of node in a heterogeneous graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example of a heterogeneous graph\n",
    "\n",
    "Many real-world graph datasets consist of nodes and edges of multiple types. For example, an **e-commerce network** is a very common heterogeneous graph. It contains at least two types of nodes (user and item) and two types of edges (buy and click). \n",
    "\n",
    "The following figure depicts several users clicking or buying some items. This graph has two types of nodes, \"user\" and \"item\", and two types of edges, \"buy\" and \"click\".\n",
    "\n",
    "![A simple heterogeneous e-commerce graph](./heter_graph_introduction.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a heterogeneous graph with PGL \n",
    "\n",
    "A heterogeneous graph has multiple types of edges, so we need to distinguish them. In PGL, edges are specified as a dictionary that maps each edge type to a list of (source, destination) node pairs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
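    "# Map each edge type to a list of (source, destination) pairs.\n",
    "# Nodes 0-3 are users and nodes 4-7 are items.\n",
    "edges = {\n",
    "    'click': [(0, 4), (0, 7), (1, 6), (2, 5), (3, 6)],\n",
    "    'buy': [(0, 5), (1, 4), (1, 6), (2, 7), (3, 5)],\n",
    "}\n",
    "\n",
    "# Add reverse edges (item -> user) as separate edge types.\n",
    "clicked = [(j, i) for i, j in edges['click']]\n",
    "bought = [(j, i) for i, j in edges['buy']]\n",
    "edges['clicked'] = clicked\n",
    "edges['bought'] = bought"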
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nodes in a heterogeneous graph also have different types, so the type of each node must be specified. Node types are given as a list of (node_id, node_type) pairs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_types = [(0, 'user'), (1, 'user'), (2, 'user'), (3, 'user'),\n",
    "              (4, 'item'), (5, 'item'), (6, 'item'), (7, 'item')]"
   ]
  },
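  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the node type format concrete, the following minimal sketch groups node ids by type using plain Python. It repeats the list so that the cell runs on its own; `example_node_types` and `ids_by_type` are illustrative names chosen here and are not part of the PGL API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# Same (node_id, node_type) pairs as above, repeated so this cell is self-contained.\n",
    "example_node_types = [(0, 'user'), (1, 'user'), (2, 'user'), (3, 'user'),\n",
    "                      (4, 'item'), (5, 'item'), (6, 'item'), (7, 'item')]\n",
    "\n",
    "# Hypothetical helper structure (not part of PGL): node ids grouped by type.\n",
    "ids_by_type = defaultdict(list)\n",
    "for node_id, node_type in example_node_types:\n",
    "    ids_by_type[node_type].append(node_id)\n",
    "\n",
    "print(dict(ids_by_type))"
   ]
  },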
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
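    "Next, we import the required packages and create random node features and a class label for each node. Edge features, if used, would also need to be separated by edge type because the edges have different types."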
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import pgl\n",
    "\n",
    "# Fix the random seeds for reproducibility.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "paddle.seed(seed)\n",
    "\n",
    "num_nodes = len(node_types)\n",
    "\n",
    "# Random 8-dimensional features for every node.\n",
    "node_features = {'features': np.random.randn(num_nodes, 8).astype(\"float32\")}\n",
    "\n",
    "# One binary class label per node.\n",
    "labels = np.array([0, 1, 0, 1, 0, 1, 1, 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can build the heterogeneous graph with PGL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = pgl.HeterGraph(edges=edges, \n",
    "                   node_types=node_types,\n",
    "                   node_feat=node_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
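    "## Message passing on a heterogeneous graph\n",
    "\n",
    "After building the heterogeneous graph, we can carry out message passing on it. In this case there are four edge types: \"click\" and \"buy\", plus the reverse types \"clicked\" and \"bought\". The layer below applies a separate weight matrix to each edge type, averages the incoming messages for each type, and sums the results."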
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HeterMessagePassingLayer(nn.Layer):\n",
    "    def __init__(self, in_dim, out_dim, etypes):\n",
    "        super(HeterMessagePassingLayer, self).__init__()\n",
    "        self.in_dim = in_dim\n",
    "        self.out_dim = out_dim\n",
    "        self.etypes = etypes\n",
    "\n",
    "        # One weight matrix per edge type. A ParameterList registers the weights\n",
    "        # with the layer, so model.parameters() returns them and the optimizer\n",
    "        # updates them; a plain Python list would not.\n",
    "        self.weight = nn.ParameterList()\n",
    "        for _ in self.etypes:\n",
    "            self.weight.append(\n",
    "                self.create_parameter(shape=[self.in_dim, self.out_dim]))\n",
    "\n",
    "    def forward(self, graph, feat):\n",
    "        def send_func(src_feat, dst_feat, edge_feat):\n",
    "            # The message on each edge is the source node's feature.\n",
    "            return src_feat\n",
    "\n",
    "        def recv_func(msg):\n",
    "            # Average the messages arriving at each destination node.\n",
    "            return msg.reduce_mean(msg[\"h\"])\n",
    "\n",
    "        feat_list = []\n",
    "        for idx, etype in enumerate(self.etypes):\n",
    "            # Transform with this edge type's weight, then pass messages along it.\n",
    "            h = paddle.matmul(feat, self.weight[idx])\n",
    "            msg = graph[etype].send(send_func, src_feat={\"h\": h})\n",
    "            h = graph[etype].recv(recv_func, msg)\n",
    "            feat_list.append(h)\n",
    "\n",
    "        # Sum the per-edge-type results.\n",
    "        h = paddle.stack(feat_list, axis=0)\n",
    "        h = paddle.sum(h, axis=0)\n",
    "\n",
    "        return h"
   ]
  },
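  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The aggregation performed by the layer above can be illustrated with a small NumPy sketch. This is not PGL code: the graph, the edge type names `'a'` and `'b'`, and the identity weights are toy values chosen for illustration, and the sketch assumes that a node receiving no messages of a given edge type contributes a zero vector for that type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy setup (illustrative values, not the tutorial's data): 3 nodes, 2-dim features.\n",
    "feat = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=\"float32\")\n",
    "toy_edges = {\n",
    "    'a': [(0, 2), (1, 2)],  # nodes 0 and 1 send to node 2\n",
    "    'b': [(2, 0)],          # node 2 sends to node 0\n",
    "}\n",
    "# One weight matrix per edge type; identity matrices keep the arithmetic easy to follow.\n",
    "weights = {etype: np.eye(2, dtype=\"float32\") for etype in toy_edges}\n",
    "\n",
    "out = np.zeros_like(feat)\n",
    "for etype, pairs in toy_edges.items():\n",
    "    h = feat @ weights[etype]  # transform features with this edge type's weight\n",
    "    agg = np.zeros_like(feat)\n",
    "    count = np.zeros(len(feat), dtype=\"float32\")\n",
    "    for src, dst in pairs:\n",
    "        agg[dst] += h[src]  # each edge carries the source node's feature\n",
    "        count[dst] += 1\n",
    "    # Mean over incoming messages; nodes without incoming edges of this type stay zero.\n",
    "    agg /= np.maximum(count, 1)[:, None]\n",
    "    out += agg  # sum the per-edge-type results\n",
    "\n",
    "print(out)"
   ]
  },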
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a simple GNN by stacking two `HeterMessagePassingLayer` layers, followed by a linear layer that produces the class logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HeterGNN(nn.Layer):\n",
    "    def __init__(self, in_dim, hidden_size, etypes, num_class):\n",
    "        super(HeterGNN, self).__init__()\n",
    "        self.in_dim = in_dim\n",
    "        self.hidden_size = hidden_size\n",
    "        self.etypes = etypes\n",
    "        self.num_class = num_class\n",
    "\n",
    "        # Two message passing layers: in_dim -> hidden_size -> hidden_size.\n",
    "        self.layers = nn.LayerList()\n",
    "        self.layers.append(\n",
    "            HeterMessagePassingLayer(self.in_dim, self.hidden_size, self.etypes))\n",
    "        self.layers.append(\n",
    "            HeterMessagePassingLayer(self.hidden_size, self.hidden_size, self.etypes))\n",
    "\n",
    "        # Linear classifier on top of the final node representations.\n",
    "        self.linear = nn.Linear(self.hidden_size, self.num_class)\n",
    "\n",
    "    def forward(self, graph, feat):\n",
    "        h = feat\n",
    "        for layer in self.layers:\n",
    "            h = layer(graph, h)\n",
    "\n",
    "        logits = self.linear(h)\n",
    "\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 8 input features, hidden size 8, 2 output classes.\n",
    "model = HeterGNN(8, 8, g.edge_types, 2)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "optim = paddle.optimizer.Adam(learning_rate=0.05,\n",
    "                              parameters=model.parameters())\n",
    "\n",
    "# Convert the graph data and the labels to Paddle tensors before training.\n",
    "g.tensor()\n",
    "labels = paddle.to_tensor(labels)\n",
    "for epoch in range(10):\n",
    "    logits = model(g, g.node_feat[\"features\"])\n",
    "    loss = criterion(logits, labels)\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    optim.clear_grad()\n",
    "\n",
    "    # float() handles both 0-D and single-element loss tensors.\n",
    "    print(\"epoch: %s | loss: %.4f\" % (epoch, float(loss)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
